51bb96
@@ -23,6 +23,7 @@
import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
@@ -219,11 +220,7 @@
public class ForwardedHeaderFilter extends OncePerRequestFilter {
 
 		private final int port;
 
-		private final String contextPath;
-
-		private final String requestUri;
-
-		private final String requestUrl;
+		private final ForwardedPrefixExtractor forwardedPrefixExtractor;
 
 
 		ForwardedHeaderExtractingRequest(HttpServletRequest request, UrlPathHelper pathHelper) {
@@ -238,28 +235,9 @@
public class ForwardedHeaderFilter extends OncePerRequestFilter {
 			this.host = uriComponents.getHost();
 			this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
 
-			String prefix = getForwardedPrefix(request);
-			this.contextPath = (prefix != null ? prefix : request.getContextPath());
-			this.requestUri = this.contextPath + pathHelper.getPathWithinApplication(request);
-			this.requestUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port) + this.requestUri;
-		}
-
-		@Nullable
-		private static String getForwardedPrefix(HttpServletRequest request) {
-			String prefix = null;
-			Enumeration<String> names = request.getHeaderNames();
-			while (names.hasMoreElements()) {
-				String name = names.nextElement();
-				if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
-					prefix = request.getHeader(name);
-				}
-			}
-			if (prefix != null) {
-				while (prefix.endsWith("/")) {
-					prefix = prefix.substring(0, prefix.length() - 1);
-				}
-			}
-			return prefix;
+			String baseUrl = this.scheme + "://" + this.host + (port == -1 ? "" : ":" + port);
+			Supplier<HttpServletRequest> delegateRequest = () -> (HttpServletRequest) getRequest();
+			this.forwardedPrefixExtractor = new ForwardedPrefixExtractor(delegateRequest, pathHelper, baseUrl);
 		}
 
 
@@ -287,18 +265,122 @@
public class ForwardedHeaderFilter extends OncePerRequestFilter {
 
 		@Override
 		public String getContextPath() {
-			return this.contextPath;
+			return this.forwardedPrefixExtractor.getContextPath(); // X-Forwarded-Prefix, else delegate's
 		}
 
 		@Override
 		public String getRequestURI() {
-			return this.requestUri;
+			return this.forwardedPrefixExtractor.getRequestUri(); // prefix-aware, re-derived after FORWARD
 		}
 
 		@Override
 		public StringBuffer getRequestURL() {
+			return this.forwardedPrefixExtractor.getRequestUrl(); // forwarded scheme/host/port + URI
+		}
+	}
+
+
+	/**
+	 * Responsible for the contextPath, requestURI, and requestURL with forwarded
+	 * headers in mind, and also taking into account changes to the path of the
+	 * underlying delegate request (e.g. on a Servlet FORWARD).
+	 */
+	private static class ForwardedPrefixExtractor {
+
+		private final Supplier<HttpServletRequest> delegate;
+
+		private final UrlPathHelper pathHelper;
+
+		private final String baseUrl;
+
+		private String actualRequestUri; // delegate URI the cached paths were computed from
+
+		@Nullable
+		private final String forwardedPrefix; // X-Forwarded-Prefix without trailing slashes
+
+		@Nullable
+		private String requestUri; // null when no X-Forwarded-Prefix header is present
+
+		private String requestUrl;
+
+
+		/**
+		 * Constructor with required information.
+		 * @param delegateRequest supplier for the current
+		 * {@link HttpServletRequestWrapper#getRequest() delegate request} which
+		 * may change during a forward (e.g. Tomcat).
+		 * @param pathHelper the path helper instance
+		 * @param baseUrl the host, scheme, and port based on forwarded headers
+		 */
+		public ForwardedPrefixExtractor(
+				Supplier<HttpServletRequest> delegateRequest, UrlPathHelper pathHelper, String baseUrl) {
+
+			this.delegate = delegateRequest;
+			this.pathHelper = pathHelper;
+			this.baseUrl = baseUrl;
+			this.actualRequestUri = delegateRequest.get().getRequestURI();
+
+			this.forwardedPrefix = initForwardedPrefix(delegateRequest.get());
+			this.requestUri = initRequestUri();
+			this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri
+		}
+
+		@Nullable
+		private static String initForwardedPrefix(HttpServletRequest request) {
+			String result = null;
+			Enumeration<String> names = request.getHeaderNames();
+			while (names.hasMoreElements()) {
+				String name = names.nextElement();
+				if ("X-Forwarded-Prefix".equalsIgnoreCase(name)) {
+					result = request.getHeader(name);
+				}
+			}
+			if (result != null) {
+				while (result.endsWith("/")) {
+					result = result.substring(0, result.length() - 1);
+				}
+			}
+			return result;
+		}
+
+		@Nullable
+		private String initRequestUri() {
+			if (this.forwardedPrefix != null) {
+				return this.forwardedPrefix + this.pathHelper.getPathWithinApplication(this.delegate.get());
+			}
+			return null;
+		}
+
+		private String initRequestUrl() {
+			return this.baseUrl + (this.requestUri != null ? this.requestUri : this.delegate.get().getRequestURI());
+		}
+
+
+		public String getContextPath() {
+			return this.forwardedPrefix == null ? this.delegate.get().getContextPath() : this.forwardedPrefix;
+		}
+
+		public String getRequestUri() {
+			if (this.requestUri == null) {
+				return this.delegate.get().getRequestURI();
+			}
+			recalculatePathsIfNecessary();
+			return this.requestUri;
+		}
+
+		public StringBuffer getRequestUrl() {
+			recalculatePathsIfNecessary();
 			return new StringBuffer(this.requestUrl);
 		}
+
+		private void recalculatePathsIfNecessary() {
+			String currentRequestUri = this.delegate.get().getRequestURI();
+			if (!this.actualRequestUri.equals(currentRequestUri)) { // underlying path changed, e.g. Servlet FORWARD
+				this.actualRequestUri = currentRequestUri;
+				this.requestUri = initRequestUri();
+				this.requestUrl = initRequestUrl(); // Keep the order: depends on requestUri
+			}
+		}
 	}
 
 
